#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import itertools
import os

from gen_elemwise_utils import ARITIES, DTYPES, MODES


def _write_kern_impl(fname, mode, anum, ctype):
    """Write one generated kernel-impl stub to *fname*.

    The file defines ``KERN_IMPL_MODE``, ``KERN_IMPL_ARITY`` and
    ``KERN_IMPL_CTYPE`` and then includes ``../kern_impl.inl``.  For the
    ``dt_float16`` and ``dt_bfloat16`` ctypes the body is wrapped in an
    ``#if !MEGDNN_DISABLE_FLOAT16`` / ``#endif`` pair.

    Args:
        fname: path of the file to create (overwritten if it exists).
        mode: elemwise mode name substituted into
            ``MEGDNN_ELEMWISE_MODE_ENABLE``.
        anum: arity value emitted as ``KERN_IMPL_ARITY``.
        ctype: C type name emitted as ``KERN_IMPL_CTYPE``.
    """
    half_guarded = ctype in ("dt_float16", "dt_bfloat16")

    lines = ["// generated by gen_elemwise_kern_impls.py"]
    if half_guarded:
        lines.append("#if !MEGDNN_DISABLE_FLOAT16")
    lines.append(
        "#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(
            mode
        )
    )
    lines.append("#define KERN_IMPL_ARITY {}".format(anum))
    lines.append("#define KERN_IMPL_CTYPE {}".format(ctype))
    lines.append('#include "../kern_impl.inl"')
    if half_guarded:
        lines.append("#endif")

    # Explicit encoding so the generated sources do not depend on the locale
    # of the machine running the generator.
    with open(fname, "w", encoding="utf-8") as fout:
        fout.write("\n".join(lines) + "\n")


def main():
    """Generate one elemwise kernel-impl file per (mode, ctype) pair.

    Command line:
        ``--type {cuda,hip,cpp}`` selects the extension of the generated
        files (``cu``, ``cpp.hip`` or ``cpp``; default ``cpp``).
        ``output`` is the directory the files are written into; it is created
        if it does not exist.

    For every arity in ``ARITIES`` and every ctype in ``DTYPES``, the modes
    are looked up in ``MODES`` under the key ``(arity, DTYPES[ctype][1])`` and
    a file named ``<mode>_<ctype>.<ext>`` is written for each of them.  The
    path of every generated file is printed to stdout.
    """
    parser = argparse.ArgumentParser(
        description="generate elemwise impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["cuda", "hip", "cpp"],
        default="cpp",
        help="generate cuda/hip kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race when several generator
    # processes target the same directory at the same time.
    os.makedirs(args.output, exist_ok=True)

    # argparse's ``choices`` guarantees args.type is one of these keys.
    cpp_ext = {"cuda": "cu", "hip": "cpp.hip", "cpp": "cpp"}[args.type]

    for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()):
        # NOTE(review): DTYPES[ctype][1] is presumably the dtype category
        # that MODES is keyed on together with the arity -- confirm against
        # gen_elemwise_utils.
        for mode in MODES[(anum, DTYPES[ctype][1])]:
            fname = os.path.join(
                args.output, "{}_{}.{}".format(mode, ctype, cpp_ext)
            )
            _write_kern_impl(fname, mode, anum, ctype)
            print("generated {}".format(fname))

    # Set the directory's access/modification times to now; presumably so a
    # build system tracking the directory sees it as freshly regenerated.
    os.utime(args.output)


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
